{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/rinongal/stylegan-nada/blob/main/stylegan_nada.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bYsd0_RFXb04"
   },
   "source": [
    "# Welcome to StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QTHeOO8qFw_e"
   },
   "source": [
    "# Step 1: Set up required libraries and models.\n",
    "This may take a few minutes.\n",
    "\n",
    "You may optionally enable downloads with pydrive in order to authenticate and avoid drive download limits when fetching pre-trained ReStyle and StyleGAN2 models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "ph3R7lbl_arQ"
   },
   "outputs": [],
   "source": [
    "#@title Setup\n",
    "%tensorflow_version 1.x\n",
    "\n",
    "import os\n",
    "\n",
    "from pydrive.auth import GoogleAuth\n",
    "from pydrive.drive import GoogleDrive\n",
    "from google.colab import auth\n",
    "from oauth2client.client import GoogleCredentials\n",
    "\n",
    "# Pre-trained checkpoints are stored here; create the folder up front.\n",
    "pretrained_model_dir = os.path.join(\"/content\", \"models\")\n",
    "os.makedirs(pretrained_model_dir, exist_ok=True)\n",
    "\n",
    "# Clone targets for the three repositories fetched later in this cell.\n",
    "restyle_dir, stylegan_ada_dir, stylegan_nada_dir = (\n",
    "    os.path.join(\"/content\", repo_name)\n",
    "    for repo_name in (\"restyle\", \"stylegan_ada\", \"stylegan_nada\")\n",
    ")\n",
    "\n",
    "# Output locations (paths only; nothing is created at this point).\n",
    "output_dir = os.path.join(\"/content\", \"output\")\n",
    "output_model_dir, output_image_dir = (\n",
    "    os.path.join(output_dir, sub_dir) for sub_dir in (\"models\", \"images\")\n",
    ")\n",
    "\n",
    "# Colab form toggle: when True, Drive downloads go through an authenticated PyDrive client.\n",
    "download_with_pydrive = True #@param {type:\"boolean\"}    \n",
    "    \n",
    "class Downloader(object):\n",
    "    \"\"\"Fetches Google Drive files, either through PyDrive or the gdown command line tool.\"\"\"\n",
    "\n",
    "    def __init__(self, use_pydrive):\n",
    "        \"\"\"Store the download mode and authenticate immediately when PyDrive is selected.\"\"\"\n",
    "        self.use_pydrive = use_pydrive\n",
    "\n",
    "        if use_pydrive:\n",
    "            self.authenticate()\n",
    "\n",
    "    def authenticate(self):\n",
    "        \"\"\"Authenticate the Colab user and keep a GoogleDrive client in self.drive.\"\"\"\n",
    "        auth.authenticate_user()\n",
    "        drive_auth = GoogleAuth()\n",
    "        drive_auth.credentials = GoogleCredentials.get_application_default()\n",
    "        self.drive = GoogleDrive(drive_auth)\n",
    "\n",
    "    def download_file(self, file_id, file_dst):\n",
    "        \"\"\"Download the Drive file identified by file_id to the local path file_dst.\"\"\"\n",
    "        if not self.use_pydrive:\n",
    "            # Unauthenticated path: shell out to gdown.\n",
    "            !gdown --id $file_id -O $file_dst\n",
    "            return\n",
    "\n",
    "        remote_file = self.drive.CreateFile({'id':file_id})\n",
    "        remote_file.FetchMetadata(fetch_all=True)\n",
    "        remote_file.GetContentFile(file_dst)\n",
    "\n",
    "downloader = Downloader(download_with_pydrive)\n",
    "\n",
    "# install requirements\n",
    "!git clone https://github.com/yuval-alaluf/restyle-encoder.git $restyle_dir\n",
    "\n",
    "# Pre-trained ReStyle encoders: local file name -> Google Drive file id.\n",
    "restyle_checkpoint_ids = {\n",
    "    \"restyle_psp_ffhq_encode.pt\": \"1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE\",\n",
    "    \"restyle_e4e_ffhq_encode.pt\": \"1e2oXVeBPXMQoUoC_4TNwAWpOPpSEhE_e\",\n",
    "}\n",
    "for restyle_ckpt_name, restyle_drive_id in restyle_checkpoint_ids.items():\n",
    "    downloader.download_file(restyle_drive_id, os.path.join(pretrained_model_dir, restyle_ckpt_name))\n",
    "\n",
    "# Fetch the Ninja v1.8.2 release binary and register it through update-alternatives.\n",
    "!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
    "!sudo unzip ninja-linux.zip -d /usr/local/bin/\n",
    "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force\n",
    "\n",
    "# Python packages: CLIP and the helper packages installed alongside it.\n",
    "!pip install ftfy regex tqdm \n",
    "!pip install git+https://github.com/openai/CLIP.git\n",
    "\n",
    "# Source repositories: StyleGAN2-ADA and StyleGAN-NADA.\n",
    "!git clone https://github.com/NVlabs/stylegan2-ada/ $stylegan_ada_dir\n",
    "!git clone https://github.com/rinongal/stylegan-nada.git $stylegan_nada_dir\n",
    "\n",
    "from argparse import Namespace\n",
    "\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Make the cloned repositories importable (appended in this order).\n",
    "sys.path.extend([\n",
    "    restyle_dir,\n",
    "    stylegan_nada_dir,\n",
    "    os.path.join(stylegan_nada_dir, \"ZSSGAN\"),\n",
    "])\n",
    "\n",
    "from restyle.utils.common import tensor2im\n",
    "from restyle.models.psp import pSp\n",
    "from restyle.models.e4e import e4e\n",
    "\n",
    "# Device used when tensors are created in the cells below.\n",
    "device = \"cuda\"\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kSL166pfGRWF"
   },
   "source": [
    "# Step 2: Choose a model type.\n",
    "Model will be downloaded and converted to a pytorch compatible version.\n",
    "\n",
    "Re-runs of the cell with the same model will re-use the previously downloaded version. Feel free to experiment and come back to previous models :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "J4ATNsC1k28g"
   },
   "outputs": [],
   "source": [
    "source_model_type = 'ffhq' #@param['ffhq', 'cat', 'dog', 'church', 'horse', 'car']\n",
    "\n",
    "# Where each source generator comes from: a direct .pkl URL or a Google Drive file id.\n",
    "source_model_download_path = dict(\n",
    "    ffhq=\"https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl\",\n",
    "    cat=\"https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqcat.pkl\",\n",
    "    dog=\"https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqdog.pkl\",\n",
    "    church=\"1iDo5cUgbwsJEt2uwfgDy_iPlaT-lLZmi\",\n",
    "    car=\"1i-39ztut-VdUVUiFuUrwdsItR--HF81w\",\n",
    "    horse=\"1irwWI291DolZhnQeW-ZyNWqZBjlWyJUn\",\n",
    ")\n",
    "\n",
    "# File name under which each pickle is stored in pretrained_model_dir.\n",
    "model_names = dict(\n",
    "    ffhq=\"ffhq.pkl\",\n",
    "    cat=\"afhqcat.pkl\",\n",
    "    dog=\"afhqdog.pkl\",\n",
    "    church=\"stylegan2-church-config-f.pkl\",\n",
    "    car=\"stylegan2-car-config-f.pkl\",\n",
    "    horse=\"stylegan2-horse-config-f.pkl\",\n",
    ")\n",
    "\n",
    "# Size value handed to training for each model type.\n",
    "dataset_sizes = dict(ffhq=1024, cat=512, dog=512, church=256, horse=256, car=512)\n",
    "\n",
    "# Resolve the entries for the selected model.\n",
    "download_string = source_model_download_path[source_model_type]\n",
    "file_name = model_names[source_model_type]\n",
    "\n",
    "# The PyTorch checkpoint reuses the pickle's name up to the first dot.\n",
    "pt_file_name = file_name.partition(\".\")[0] + \".pt\"\n",
    "\n",
    "# Fetch the pickle unless a previous run already stored it.\n",
    "if not os.path.isfile(os.path.join(pretrained_model_dir, file_name)):\n",
    "    print(\"Downloading chosen model...\")\n",
    "\n",
    "    if download_string.endswith(\".pkl\"):\n",
    "        # Direct URL.\n",
    "        !wget $download_string -O $pretrained_model_dir/$file_name\n",
    "    else:\n",
    "        # Anything else is treated as a Google Drive file id.\n",
    "        downloader.download_file(download_string, os.path.join(pretrained_model_dir, file_name))\n",
    "\n",
    "# Convert to a PyTorch checkpoint unless the .pt file is already present.\n",
    "if not os.path.isfile(os.path.join(pretrained_model_dir, pt_file_name)):\n",
    "    print(\"Converting sg2 model. This may take a few minutes...\")\n",
    "\n",
    "    # The conversion script runs as a separate process, so it receives the TensorFlow\n",
    "    # entry of sys.path (when there is one) plus the ZSSGAN sources through PYTHONPATH.\n",
    "    zssgan_path = f\"{stylegan_nada_dir}/ZSSGAN\"\n",
    "    tf_path = next(filter(lambda x: \"tensorflow\" in x, sys.path), None)\n",
    "    if tf_path is None:\n",
    "        # Without this branch the concatenation below failed with a TypeError (None + str).\n",
    "        print(\"Warning: no TensorFlow entry found in sys.path; using the default Python path for conversion.\")\n",
    "        py_path = zssgan_path\n",
    "    else:\n",
    "        py_path = f\"{tf_path}:{zssgan_path}\"\n",
    "    convert_script = os.path.join(stylegan_nada_dir, \"convert_weight.py\")\n",
    "    !PYTHONPATH=$py_path python $convert_script --repo $stylegan_ada_dir --gen $pretrained_model_dir/$file_name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DAri8ULOG2VE"
   },
   "source": [
    "# Step 3: Train the model.\n",
    "Describe your source and target class. These describe the direction of change you're trying to apply (e.g. \"photo\" to \"sketch\", \"dog\" to \"the joker\" or \"dog\" to \"avocado dog\").\n",
    "\n",
    "We recommend leaving the 'improve shape' button unticked at first, as it will lead to an increase in running times and is often not needed.\n",
    "For more drastic changes, turn it on and increase the number of iterations.\n",
    "\n",
    "As a rule of thumb:\n",
    "- Style and minor domain changes ('photo' -> 'sketch') require ~200-400 iterations.\n",
    "- Identity changes ('person' -> 'taylor swift') require ~150-200 iterations.\n",
    "- Simple in-domain changes ('face' -> 'smiling face') may require as few as 50.\n",
    "\n",
    "> Updates:\n",
    "> 03/08 - Added support for saving model checkpoints. If you want to save, set save_interval > 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "8YrtPb7KF8m-"
   },
   "outputs": [],
   "source": [
    "from ZSSGAN.model.ZSSGAN import ZSSGAN\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "from tqdm import notebook\n",
    "\n",
    "from ZSSGAN.utils.file_utils import save_images\n",
    "from ZSSGAN.utils.training_utils import mixing_noise\n",
    "\n",
    "from IPython.display import display\n",
    "\n",
    "source_class = \"Photo\" #@param {\"type\": \"string\"}\n",
    "target_class = \"Sketch\" #@param {\"type\": \"string\"}\n",
    "\n",
    "improve_shape = False #@param{type:\"boolean\"}\n",
    "\n",
    "# CLIP models used for the loss and their weights. ViT-B/16 only receives a\n",
    "# non-zero weight when 'improve shape' is ticked.\n",
    "model_choice = [\"ViT-B/32\", \"ViT-B/16\"]\n",
    "model_weights = [1.0, 0.0]\n",
    "\n",
    "if improve_shape:\n",
    "    model_weights[1] = 1.0\n",
    "\n",
    "# Mixing and the auto-layer settings are likewise only enabled by 'improve shape'.\n",
    "mixing = 0.9 if improve_shape else 0.0\n",
    "\n",
    "# Two thirds of (2 * log2(size) - 2), rounded down.\n",
    "# NOTE(review): that term is presumably the generator's layer count - confirm.\n",
    "auto_layers_k = int(2 * (2 * np.log2(dataset_sizes[source_model_type]) - 2) / 3) if improve_shape else 0\n",
    "auto_layer_iters = 1 if improve_shape else 0\n",
    "\n",
    "training_iterations = 151 #@param {type: \"integer\"}\n",
    "output_interval     = 50 #@param {type: \"integer\"}\n",
    "save_interval       = 0 #@param {type: \"integer\"}\n",
    "\n",
    "# Arguments consumed by ZSSGAN and by the training loop below.\n",
    "training_args = {\n",
    "    \"size\": dataset_sizes[source_model_type],\n",
    "    \"batch\": 2,\n",
    "    \"n_sample\": 4,\n",
    "    \"output_dir\": output_dir,\n",
    "    \"lr\": 0.002,\n",
    "    # Frozen and trainable generators both start from the converted checkpoint.\n",
    "    \"frozen_gen_ckpt\": os.path.join(pretrained_model_dir, pt_file_name),\n",
    "    \"train_gen_ckpt\": os.path.join(pretrained_model_dir, pt_file_name),\n",
    "    \"iter\": training_iterations,\n",
    "    \"source_class\": source_class,\n",
    "    \"target_class\": target_class,\n",
    "    \"lambda_direction\": 1.0,\n",
    "    \"lambda_patch\": 0.0,\n",
    "    \"lambda_global\": 0.0,\n",
    "    \"lambda_texture\": 0.0,\n",
    "    \"lambda_manifold\": 0.0,\n",
    "    \"auto_layer_k\": auto_layers_k,\n",
    "    \"auto_layer_iters\": auto_layer_iters,\n",
    "    \"auto_layer_batch\": 8,\n",
    "    # Forward the form value (this used to be a hard-coded 50) so that args agrees\n",
    "    # with the interval the training loop actually uses.\n",
    "    \"output_interval\": output_interval,\n",
    "    \"clip_models\": model_choice,\n",
    "    \"clip_model_weights\": model_weights,\n",
    "    \"mixing\": mixing,\n",
    "    \"phase\": None,\n",
    "    \"sample_truncation\": 0.7,\n",
    "    \"save_interval\": save_interval,\n",
    "    \"target_img_list\": None\n",
    "}\n",
    "\n",
    "args = Namespace(**training_args)\n",
    "\n",
    "print(\"Loading base models...\")\n",
    "net = ZSSGAN(args)\n",
    "print(\"Models loaded! Starting training...\")\n",
    "\n",
    "# Adam hyper-parameters: learning rate and betas are all rescaled by g_reg_ratio.\n",
    "g_reg_ratio = 4 / 5\n",
    "adam_lr = args.lr * g_reg_ratio\n",
    "adam_betas = (0 ** g_reg_ratio, 0.99 ** g_reg_ratio)\n",
    "\n",
    "# Only the trainable generator's parameters are optimized.\n",
    "g_optim = torch.optim.Adam(net.generator_trainable.parameters(), lr=adam_lr, betas=adam_betas)\n",
    "\n",
    "# Set up output directories.\n",
    "sample_dir = os.path.join(args.output_dir, \"sample\")\n",
    "ckpt_dir = os.path.join(args.output_dir, \"checkpoint\")\n",
    "for out_dir in (sample_dir, ckpt_dir):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "seed = 3 #@param {\"type\": \"integer\"}\n",
    "\n",
    "# Seed both torch and numpy before the training loop draws any noise.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Training loop\n",
    "# Latents drawn once so that every preview is rendered from the same inputs.\n",
    "fixed_z = torch.randn(args.n_sample, 512, device=device)\n",
    "\n",
    "for i in notebook.tqdm(range(args.iter)):\n",
    "    net.train()\n",
    "\n",
    "    sample_z = mixing_noise(args.batch, 512, args.mixing, device)\n",
    "\n",
    "    [sampled_src, sampled_dst], clip_loss = net(sample_z)\n",
    "\n",
    "    net.zero_grad()\n",
    "    clip_loss.backward()\n",
    "\n",
    "    g_optim.step()\n",
    "\n",
    "    # Periodic preview. An interval of 0 now disables previews instead of raising\n",
    "    # ZeroDivisionError in the modulo below.\n",
    "    if (output_interval != 0) and (i % output_interval == 0):\n",
    "        net.eval()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            [sampled_src, sampled_dst], loss = net([fixed_z], truncation=args.sample_truncation)\n",
    "\n",
    "            # Car outputs are cropped to rows 64:448 before saving.\n",
    "            if source_model_type == 'car':\n",
    "                sampled_dst = sampled_dst[:, :, 64:448, :]\n",
    "\n",
    "            grid_rows = 4\n",
    "\n",
    "            save_images(sampled_dst, sample_dir, \"dst\", grid_rows, i)\n",
    "\n",
    "            # Re-open the saved grid and show it inline.\n",
    "            img = Image.open(os.path.join(sample_dir, f\"dst_{str(i).zfill(6)}.jpg\")).resize((1024, 256))\n",
    "            display(img)\n",
    "\n",
    "    # Checkpointing is skipped at iteration 0 and whenever save_interval is not positive.\n",
    "    if (args.save_interval > 0) and (i > 0) and (i % args.save_interval == 0):\n",
    "        torch.save(\n",
    "            {\n",
    "                \"g_ema\": net.generator_trainable.generator.state_dict(),\n",
    "                \"g_optim\": g_optim.state_dict(),\n",
    "            },\n",
    "            f\"{ckpt_dir}/{str(i).zfill(6)}.pt\",\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9ZZk6yZQvxGY"
   },
   "source": [
    "# Step 4: Generate samples with the new model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "dLinyTgev5Qk"
   },
   "outputs": [],
   "source": [
    "truncation = 0.7 #@param {type:\"slider\", min:0, max:1, step:0.05}\n",
    "\n",
    "samples = 9\n",
    "\n",
    "with torch.no_grad():\n",
    "    net.eval()\n",
    "\n",
    "    # Draw fresh latents and run them through the model with the chosen truncation.\n",
    "    sample_z = torch.randn(samples, 512, device=device)\n",
    "    [sampled_src, sampled_dst], loss = net([sample_z], truncation=truncation)\n",
    "\n",
    "    # Car outputs are cropped to rows 64:448 before saving.\n",
    "    if source_model_type == 'car':\n",
    "        sampled_dst = sampled_dst[:, :, 64:448, :]\n",
    "\n",
    "    # Grid-row value is the integer square root of the sample count (3 for 9 samples).\n",
    "    grid_rows = int(samples ** 0.5)\n",
    "    save_images(sampled_dst, sample_dir, \"sampled\", grid_rows, 0)\n",
    "\n",
    "    # Re-open the saved grid and show it inline.\n",
    "    sampled_grid_path = os.path.join(sample_dir, f\"sampled_{str(0).zfill(6)}.jpg\")\n",
    "    display(Image.open(sampled_grid_path).resize((768, 768)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e4hVHBrlGxzo"
   },
   "source": [
    "## Editing a real image with Re-Style inversion (currently only FFHQ inversion is supported):"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jv4h9WZ2gZio"
   },
   "source": [
    "Step 1: Set up the Re-Style model.\n",
    "\n",
    "Choose the pSp model for better reconstructions. \n",
    "We found that for some extreme modifications (typically those which require 500+ training iterations), e4e may perform better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AcSXQ3AbPrqh",
    "outputId": "a3151aa8-ced3-41c1-990c-faf7cd40f1ac"
   },
   "outputs": [],
   "source": [
    "encoder_type = 'psp' #@param['psp', 'e4e']\n",
    "\n",
    "# Input preprocessing: resize to 256x256, convert to tensor, normalize each channel with mean 0.5 / std 0.5.\n",
    "restyle_transform = transforms.Compose([\n",
    "    transforms.Resize((256, 256)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "restyle_experiment_args = {\n",
    "    \"model_path\": os.path.join(pretrained_model_dir, f\"restyle_{encoder_type}_ffhq_encode.pt\"),\n",
    "    \"transform\": restyle_transform,\n",
    "}\n",
    "\n",
    "model_path = restyle_experiment_args['model_path']\n",
    "ckpt = torch.load(model_path, map_location='cpu')\n",
    "\n",
    "# Reuse the options stored in the checkpoint, pointing them at the local checkpoint file.\n",
    "restyle_opts = ckpt['opts']\n",
    "restyle_opts['checkpoint_path'] = model_path\n",
    "opts = Namespace(**restyle_opts)\n",
    "\n",
    "# Pick the encoder class matching the form selection.\n",
    "if encoder_type == 'psp':\n",
    "    encoder_cls = pSp\n",
    "else:\n",
    "    encoder_cls = e4e\n",
    "restyle_net = encoder_cls(opts)\n",
    "\n",
    "restyle_net.eval()\n",
    "restyle_net.cuda()\n",
    "print('Model successfully loaded!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HfB-jTnZgn0D"
   },
   "source": [
    "Step 2: Align and invert an image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "2tMd5WBvE0Ol"
   },
   "outputs": [],
   "source": [
    "def run_alignment(image_path):\n",
    "    \"\"\"Align the face found in the image at image_path.\n",
    "\n",
    "    The dlib 68-landmark predictor file is downloaded and decompressed into the\n",
    "    current working directory the first time it is needed.\n",
    "\n",
    "    Args:\n",
    "        image_path: path of the image file to align.\n",
    "\n",
    "    Returns:\n",
    "        The aligned image returned by align_face.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if downloading or decompressing the predictor file fails.\n",
    "    \"\"\"\n",
    "    import dlib\n",
    "    from scripts.align_faces_parallel import align_face\n",
    "\n",
    "    predictor_file = \"shape_predictor_68_face_landmarks.dat\"\n",
    "    if not os.path.exists(predictor_file):\n",
    "        print('Downloading files for aligning face image...')\n",
    "        setup_commands = (\n",
    "            'wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2',\n",
    "            'bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2',\n",
    "        )\n",
    "        for command in setup_commands:\n",
    "            # os.system signals failure with a non-zero status. Stop right away so the\n",
    "            # problem is reported here rather than as a missing-file error from dlib.\n",
    "            if os.system(command) != 0:\n",
    "                raise RuntimeError(f\"Command failed: {command}\")\n",
    "        print('Done.')\n",
    "    predictor = dlib.shape_predictor(predictor_file)\n",
    "    aligned_image = align_face(filepath=image_path, predictor=predictor)\n",
    "    print(\"Aligned image has shape: {}\".format(aligned_image.size))\n",
    "    return aligned_image\n",
    "\n",
    "image_path = \"/content/ariana.jpg\" #@param {'type': 'string'}\n",
    "\n",
    "# Unaligned picture, loaded as RGB.\n",
    "original_image = Image.open(image_path)\n",
    "original_image = original_image.convert(\"RGB\")\n",
    "\n",
    "# Align the face and show the aligned result.\n",
    "input_image = run_alignment(image_path)\n",
    "display(input_image)\n",
    "\n",
    "# Apply the preprocessing stored with the ReStyle settings.\n",
    "img_transforms = restyle_experiment_args['transform']\n",
    "transformed_image = img_transforms(input_image)\n",
    "\n",
    "def get_avg_image(net):\n",
    "    \"\"\"Return the image that net produces for its average latent code.\n",
    "\n",
    "    The first element of the network output is moved to CUDA, cast to float and detached.\n",
    "    \"\"\"\n",
    "    mean_code = net.latent_avg.unsqueeze(0)\n",
    "    outputs = net(mean_code,\n",
    "                  input_code=True,\n",
    "                  randomize_noise=False,\n",
    "                  return_latents=False,\n",
    "                  average_code=True)\n",
    "    return outputs[0].to('cuda').float().detach()\n",
    "\n",
    "# ReStyle inference settings.\n",
    "opts.n_iters_per_batch = 5\n",
    "opts.resize_outputs = False  # generate outputs at full resolution\n",
    "\n",
    "from restyle.utils.inference_utils import run_on_batch\n",
    "\n",
    "with torch.no_grad():\n",
    "    avg_image = get_avg_image(restyle_net)\n",
    "    # Single-image batch on the GPU.\n",
    "    input_batch = transformed_image.unsqueeze(0).cuda()\n",
    "    result_batch, result_latents = run_on_batch(input_batch, restyle_net, opts, avg_image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XOiIZcJUgsQS"
   },
   "source": [
    "Step 3: Convert the image to the new domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "u5JqEOMnEA_m"
   },
   "outputs": [],
   "source": [
    "#@title Convert inverted image.\n",
    "# Take the last latent recorded for the first image. The previous hard-coded index 4\n",
    "# only pointed at the last entry while opts.n_iters_per_batch was 5.\n",
    "# NOTE(review): assumes result_latents[0] is a sequence with one latent per ReStyle iteration - confirm.\n",
    "inverted_latent = torch.Tensor(result_latents[0][-1]).cuda().unsqueeze(0).unsqueeze(1)\n",
    "\n",
    "with torch.no_grad():\n",
    "    net.eval()\n",
    "\n",
    "    # Feed the inverted code to the model as a latent and take its first output (source and adapted images).\n",
    "    [sampled_src, sampled_dst] = net(inverted_latent, input_is_latent=True)[0]\n",
    "\n",
    "    # Save both images together in one grid and show it inline.\n",
    "    joined_img = torch.cat([sampled_src, sampled_dst], dim=0)\n",
    "    save_images(joined_img, sample_dir, \"joined\", 2, 0)\n",
    "    display(Image.open(os.path.join(sample_dir, f\"joined_{str(0).zfill(6)}.jpg\")).resize((512, 256)))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyM+XErRWoKRUgyMSyUOOQav",
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "stylegan_nada.ipynb",
   "provenance": []
  },
  "interpreter": {
   "hash": "fd69f43f58546b570e94fd7eba7b65e6bcc7a5bbc4eab0408017d18902915d69"
  },
  "kernelspec": {
   "display_name": "Python 3.7.5 64-bit",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
